package com.offcn.bigdata.streaming.p2

import org.apache.spark.SparkConf
import org.apache.spark.streaming.{Seconds, StreamingContext}

/**
  * Rewrite of the earlier updateStateByKey example with driver HA support:
  * the StreamingContext is obtained via StreamingContext.getOrCreate, so a
  * restarted driver recovers its context (and accumulated per-key state)
  * from the checkpoint directory instead of starting from scratch.
  */
object _06DriverHA2UpdateStateByKeyOps {
    def main(args: Array[String]): Unit = {
        if(args == null || args.length != 4) {
            println(
                """
                  |Usage: <batchInterval> <host> <port> <checkpoint>
                """.stripMargin)
            System.exit(-1)
        }
        val Array(batchInterval, host, port, checkpoint) = args

        /**
          * Builds a brand-new StreamingContext. Invoked by getOrCreate only
          * when no usable checkpoint exists at `checkpoint`; otherwise the
          * previous context and its DStream lineage are restored instead.
          */
        def creatingFunc(): StreamingContext = {
            val conf = new SparkConf()
                // Keep the app name consistent with this object
                // (was "_04UpdateStateByKeyOps", a copy-paste leftover).
                .setAppName("_06DriverHA2UpdateStateByKeyOps")

            val ssc = new StreamingContext(conf, Seconds(batchInterval.toLong))

            // Checkpointing is mandatory both for updateStateByKey (state
            // lineage truncation) and for driver recovery via getOrCreate.
            ssc.checkpoint(checkpoint)
            val lines = ssc.socketTextStream(host, port.toInt)

            val pairs = lines.flatMap(_.split("\\s+")).map((_, 1))

            // Running word count: per-key state accumulates across batches.
            val ret = pairs.updateStateByKey(updateFunc)

            // print() has a side effect, so call it with parentheses.
            ret.print()

            ssc
        }
        val ssc = StreamingContext.getOrCreate(checkpoint, creatingFunc)

        ssc.start()
        ssc.awaitTermination()
    }

    /**
      * State update function passed to updateStateByKey.
      *
      * @param current counts observed for a key in the current batch
      * @param history accumulated count for the key from previous batches,
      *                None the first time the key is seen
      * @return the new accumulated count, stored as the key's state
      */
    def updateFunc(current: Seq[Int], history: Option[Int]): Option[Int] = {
        println(s"current: ${current}, history: ${history.getOrElse(0)}")
        val sum = current.sum + history.getOrElse(0)
        Option(sum)
    }
}
